{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exporting MXNet models to ONNX\n",
    "\n",
    "In this tutorial, we show how to save MXNet models in the ONNX format.\n",
    "The ONNX exporter is part of the [MXNet repository](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/contrib/onnx/mx2onnx).\n",
    "\n",
    "The current MXNet-ONNX import and export operator support and coverage can be found [here](https://cwiki.apache.org/confluence/display/MXNET/ONNX).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Installations\n",
    "#### Install **ONNX version 1.2.1.** Follow the instructions in the [ONNX repo](https://github.com/onnx/onnx).\n",
    "\n",
    "The MXNet ONNX importer and exporter follow version 7 of the ONNX operator set.\n",
    "\n",
    "#### Make sure to install the latest MXNet, either from source or with pip.\n",
    "\n",
    "```bash\n",
    "pip install mxnet --pre\n",
    "```\n",
    "On Windows, run `python -m pip install mxnet --pre` in a `cmd.exe` window, and make sure Python is installed properly.\n",
    "\n",
    "* Note: The ONNX exporter will be released as a part of MXNet v1.3."
   ]
  },
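  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm what is installed before continuing, a quick check along the following lines can help. This is a minimal sketch, not part of the MXNet or ONNX tooling: `installed_version` is a hypothetical helper introduced here, and it assumes each package exposes a `__version__` attribute.\n",
    "\n",
    "```python\n",
    "import importlib\n",
    "\n",
    "\n",
    "def installed_version(package_name):\n",
    "    # Hypothetical helper: return the package's __version__ string,\n",
    "    # or None if the package cannot be imported.\n",
    "    try:\n",
    "        module = importlib.import_module(package_name)\n",
    "    except ImportError:\n",
    "        return None\n",
    "    return getattr(module, '__version__', None)\n",
    "\n",
    "\n",
    "for name in ('mxnet', 'onnx'):\n",
    "    print('%s: %s' % (name, installed_version(name)))\n",
    "```\n",
    "\n",
    "A result of `None` suggests the package is not importable in the environment the notebook kernel is using."
   ]
  },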
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Prepare the MXNet model to convert to ONNX\n",
    "\n",
    "Let's try out a pretrained ResNet-18 model from the [MXNet model zoo](http://data.mxnet.io/models/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "import numpy as np\n",
    "from mxnet.contrib import onnx as onnx_mxnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['resnet-18-0000.params', 'resnet-18-symbol.json', 'synset.txt']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Download pretrained resnet model - json and params from mxnet model zoo.\n",
    "path='http://data.mxnet.io/models/imagenet/'\n",
    "[mx.test_utils.download(path+'resnet/18-layers/resnet-18-0000.params'),\n",
    " mx.test_utils.download(path+'resnet/18-layers/resnet-18-symbol.json'),\n",
    " mx.test_utils.download(path+'synset.txt')]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Use the MXNet to ONNX exporter\n",
    "\n",
    "MXNet's ONNX `export_model` API accepts the following inputs:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on function export_model in module mxnet.contrib.onnx.mx2onnx.export_model:\n",
      "\n",
      "export_model(sym, params, input_shape, input_type=<type 'numpy.float32'>, onnx_file_path=u'model.onnx', verbose=False)\n",
      "    Exports the MXNet model file, passed as a parameter, into ONNX model.\n",
      "    Accepts both symbol,parameter objects as well as json and params filepaths as input.\n",
      "    Operator support and coverage - https://cwiki.apache.org/confluence/display/MXNET/ONNX\n",
      "    \n",
      "    Parameters\n",
      "    ----------\n",
      "    sym : str or symbol object\n",
      "        Path to the json file or Symbol object\n",
      "    params : str or symbol object\n",
      "        Path to the params file or params dictionary. (Including both arg_params and aux_params)\n",
      "    input_shape : List of tuple\n",
      "        Input shape of the model e.g [(1,3,224,224)]\n",
      "    input_type : data type\n",
      "        Input data type e.g. np.float32\n",
      "    onnx_file_path : str\n",
      "        Path where to save the generated onnx file\n",
      "    verbose : Boolean\n",
      "        If true will print logs of the model conversion\n",
      "    \n",
      "    Returns\n",
      "    -------\n",
      "    onnx_file_path : str\n",
      "        Onnx file path\n",
      "\n"
     ]
    }
   ],
   "source": [
    "help(onnx_mxnet.export_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the API description, you can see that the `export_model` API accepts two kinds of inputs:\n",
    "\n",
    "### MXNet sym, params objects:\n",
    "\n",
    "This is useful if you are training a model. At the end of training, you just need to invoke the `export_model` function and provide the sym and params objects as inputs, along with the other attributes, to save the model in ONNX format.\n",
    "\n",
    "### MXNet's exported json and params files:\n",
    "\n",
    "This is useful if you have pretrained models and you want to convert them to ONNX format.\n",
    "\n",
    "In this tutorial, we will show the second use case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Downloaded input symbol and params files\n",
    "sym = 'resnet-18-symbol.json'\n",
    "params = 'resnet-18-0000.params'\n",
    "# Standard ImageNet input: batch size 1, 3 channels, 224x224 pixels\n",
    "input_shape = (1,3,224,224)\n",
    "# Path of the output file\n",
    "onnx_file = 'mxnet_exported_resnet18.onnx'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Invoke the export_model API. It returns the path of the converted ONNX model\n",
    "converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mxnet_exported_resnet18.onnx\n"
     ]
    }
   ],
   "source": [
    "print(converted_model_path)"
   ]
  },
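  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the first use case, where the model is already held as in-memory objects, a sketch along these lines may work. It assumes the checkpoint files from Step 2 are in the working directory. `merge_params` and `export_from_checkpoint` are hypothetical helpers introduced for illustration, the output file name is arbitrary, and `mx.model.load_checkpoint` stands in for a finished training run that would provide the same three objects.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def merge_params(arg_params, aux_params):\n",
    "    # export_model takes one dictionary that includes both\n",
    "    # arg_params and aux_params, so combine them into a new dict.\n",
    "    merged = dict(arg_params)\n",
    "    merged.update(aux_params)\n",
    "    return merged\n",
    "\n",
    "\n",
    "def export_from_checkpoint(prefix, epoch, input_shape, onnx_path):\n",
    "    # Imports are local so merge_params stays usable without MXNet.\n",
    "    import mxnet as mx\n",
    "    import numpy as np\n",
    "    from mxnet.contrib import onnx as onnx_mxnet\n",
    "\n",
    "    # Load the Symbol object and both parameter dictionaries.\n",
    "    sym_obj, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)\n",
    "    params_obj = merge_params(arg_params, aux_params)\n",
    "    return onnx_mxnet.export_model(sym_obj, params_obj, [input_shape],\n",
    "                                   np.float32, onnx_path)\n",
    "\n",
    "\n",
    "# Run only when the checkpoint downloaded in Step 2 is present.\n",
    "if os.path.exists('resnet-18-symbol.json') and os.path.exists('resnet-18-0000.params'):\n",
    "    print(export_from_checkpoint('resnet-18', 0, (1, 3, 224, 224),\n",
    "                                 'resnet18_from_objects.onnx'))\n",
    "```\n",
    "\n",
    "Merging into a new dictionary leaves the original `arg_params` and `aux_params` untouched, which matters if training continues after the export."
   ]
  },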
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Check validity\n",
    "\n",
    "You can check the validity of the converted ONNX model by using the ONNX checker tool as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnx import checker\n",
    "import onnx\n",
    "# Load onnx model\n",
    "model_proto = onnx.load(converted_model_path)\n",
    "\n",
    "# Check if converted ONNX protobuf is valid\n",
    "checker.check_graph(model_proto.graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`check_graph` raises an exception if the graph is invalid, so returning without an error confirms that the converted model protobuf is valid."
   ]
  },
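  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond the checker, it can be useful to look at what the exported file declares. The sketch below collects the operator set versions and the graph input and output names from a loaded model. `summarize_model` is a hypothetical helper, and `onnx_path` is a placeholder that should be set to the file path returned by `export_model` in Step 3.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def summarize_model(model_proto):\n",
    "    # Hypothetical helper: read a few fields from an ONNX ModelProto.\n",
    "    return {\n",
    "        'opsets': [(opset.domain, opset.version) for opset in model_proto.opset_import],\n",
    "        'inputs': [value.name for value in model_proto.graph.input],\n",
    "        'outputs': [value.name for value in model_proto.graph.output],\n",
    "    }\n",
    "\n",
    "\n",
    "# Placeholder (assumption): set this to the path returned by export_model.\n",
    "onnx_path = 'model.onnx'\n",
    "\n",
    "# Load and summarize only if the file is actually there.\n",
    "if os.path.exists(onnx_path):\n",
    "    import onnx\n",
    "    print(summarize_model(onnx.load(onnx_path)))\n",
    "```\n",
    "\n",
    "The reported operator set version can be compared with the version mentioned in Step 1."
   ]
  },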
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What's next\n",
    "\n",
    "Take a look at [other tutorials, including importing ONNX models into other frameworks](https://github.com/onnx/tutorials/tree/master/tutorials)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "onnx_mxnet",
   "language": "python",
   "name": "onnx_mxnet"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
